package mybatis.generator.support.generator;

import mybatis.generator.FilterType;
import mybatis.generator.MybatisGenerator;
import mybatis.generator.annotation.ExecuteUpdate;
import mybatis.generator.annotation.GeneratedSql;
import mybatis.generator.ext.InterceptorContext;
import mybatis.generator.support.Reflection;
import mybatis.generator.support.StringUtils;
import mybatis.generator.support.parse.*;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.reflection.ParamNameResolver;
import org.apache.ibatis.scripting.xmltags.XMLLanguageDriver;
import org.apache.ibatis.session.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Method;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Generates dynamic {@code UPDATE} statements for mapper methods annotated with {@link ExecuteUpdate}.
 *
 * <p>Created by huangdachao on 2018/6/14 16:59.
 */
public class UpdateGenerator {
    private static final Logger LOG = LoggerFactory.getLogger(UpdateGenerator.class);

    /**
     * Creates a {@link SqlSource} that builds an {@code UPDATE} statement for {@code method} every time the
     * mapped statement is executed.
     *
     * <p>On each execution the returned source:
     * <ol>
     *   <li>resolves the invocation arguments and runs the registered interceptors;</li>
     *   <li>collects column/value pairs from the arguments, from {@link ExecuteUpdate#columnValue()}, and from
     *       the interceptor context;</li>
     *   <li>builds the join and where clauses. If the only argument is an instance of exactly the entity class,
     *       it also adds equality filters on the entity's primary-key columns;</li>
     *   <li>renders a {@code <script>} and binds it through MyBatis' {@link XMLLanguageDriver}.</li>
     * </ol>
     *
     * @param conf        MyBatis configuration used to resolve arguments and create the script source
     * @param mapperClass mapper interface declaring {@code method}; its entity class is the update target
     * @param method      mapper method being mapped; must carry {@link ExecuteUpdate}
     * @return a SQL source producing the bound {@code UPDATE} statement
     */
    public static SqlSource generate(Configuration conf, Class<?> mapperClass, Method method) {
        EntityMeta em = EntityMeta.get(Reflection.getEntityClass(mapperClass));
        // These depend only on the method, so resolve them once instead of on every execution.
        // NOTE(review): assumes the Configuration's parameter-naming settings (e.g. useActualParamName)
        // are final by the time generate() is called.
        ExecuteUpdate executeUpdate = method.getAnnotation(ExecuteUpdate.class);
        ParamNameResolver paramNameResolver = new ParamNameResolver(conf, method);

        return parameterObject -> {
            if (executeUpdate == null) {
                throw new IllegalStateException(
                        "Missing @ExecuteUpdate annotation: " + Reflection.formatMethod(mapperClass, method));
            }
            Object[] args = Reflection.parseArgs(conf, method, parameterObject);
            String[] argNames = paramNameResolver.getNames();
            InterceptorContext context = new InterceptorContext(
                    mapperClass, em.getEntityClass(), method, args, GeneratedSql.Type.UPDATE);
            MybatisGenerator.intercept(context);

            List<ColumnValuePair> cvp = ColumnValuePair.parseArgs(em, method, args, argNames);
            // columnValue() is read per execution: annotation proxies return a fresh array copy each call.
            ColumnValuePair.mergeFromColumnValues(executeUpdate.columnValue(), em, cvp);
            Map<String, Object> additionalParams = new HashMap<>();
            ColumnValuePair.mergeFromInterceptorContext(context, additionalParams, cvp);
            if (cvp.isEmpty()) {
                throw new IllegalArgumentException("没有有效的可更新字段" + Reflection.formatMethod(mapperClass, method));
            }

            Set<Table> tables = cvp.stream().flatMap(c -> c.getReferTables().stream()).collect(Collectors.toSet());
            FromAndWhereGenerator faw = new FromAndWhereGenerator(em, method, args, argNames, context, tables);
            prependPrimaryKeyFilters(em, method, args, faw);
            // NOTE(review): if no filter is produced at all, the WHERE clause is presumably empty and the
            // statement updates every row — confirm FromAndWhereGenerator guards against that.

            String sqlStr = buildUpdateScript(em, cvp, faw);
            LOG.info(sqlStr);
            Class<?> parameterType = parameterObject == null ? Object.class : parameterObject.getClass();
            SqlSource sqlSource = conf.getLanguageRegistry().getDriver(XMLLanguageDriver.class)
                .createSqlSource(conf, sqlStr, parameterType);
            BoundSql bsql = sqlSource.getBoundSql(parameterObject);
            faw.getAdditionalParams().forEach(bsql::setAdditionalParameter);
            additionalParams.forEach(bsql::setAdditionalParameter);
            return bsql;
        };
    }

    /**
     * Adds primary-key filters ahead of the existing ones when the only argument is the entity itself.
     *
     * <p>This applies when {@code method} takes exactly one argument, that argument is non-null, and its
     * runtime class is exactly the entity class. In that case one filter per primary-key column is added,
     * reading the key value from argument 0, so the update targets the entity's own row. Filters already
     * held by {@code faw} are kept after the primary-key filters.
     *
     * <p>NOTE(review): the exact-class comparison means subclass or proxy instances of the entity do not get
     * the primary-key filters — confirm this is intended.
     *
     * @throws IllegalStateException if a primary-key column has no mapped entity field
     */
    private static void prependPrimaryKeyFilters(EntityMeta em, Method method, Object[] args,
                                                 FromAndWhereGenerator faw) {
        if (method.getParameterCount() != 1 || args[0] == null || args[0].getClass() != em.getEntityClass()) {
            return;
        }
        List<FilterMeta> filterMeta = new ArrayList<>();
        em.getPrimaryColumns().forEach(pk -> {
            FilterMeta fm = new FilterMeta();
            fm.setTableColumn(pk);
            fm.setArgIndex(0);
            fm.setField(Optional.ofNullable(em.getFieldMeta().get(pk))
                    .map(f -> f.getField().getName())
                    .orElseThrow(() -> new IllegalStateException("No entity field mapped to primary key column "
                            + pk + " of " + em.getEntityClass().getName())));
            filterMeta.add(fm);
        });

        if (!filterMeta.isEmpty()) {
            filterMeta.addAll(faw.getFilterMeta());
            faw.setFilterMeta(filterMeta);
        }
    }

    /**
     * Renders the complete {@code <script>} for the update: target table (alias {@code t0}), join clause,
     * comma-separated SET assignments, and the WHERE clause.
     *
     * <p>The calls into {@code faw} keep their original order. The join clause is generated before the
     * assignments read {@link FromAndWhereGenerator#getTableAliasMap()}, and the where clause is generated last.
     *
     * @param cvp non-empty list of column/value pairs to assign
     */
    private static String buildUpdateScript(EntityMeta em, List<ColumnValuePair> cvp, FromAndWhereGenerator faw) {
        StringBuilder sql = new StringBuilder("<script>update ").append(em.getTable().table).append(" t0 ")
            .append(faw.generateJoinClause()).append(" set ");

        StringJoiner assignments = new StringJoiner(",");
        for (ColumnValuePair c : cvp) {
            assignments.add("t0." + c.getColumn() + FilterType.eq.toSql(
                    StringUtils.normalizeTableAndFieldExpression(c.getPlaceholder(), em, faw.getTableAliasMap()), null));
        }
        sql.append(assignments).append(" ");

        sql.append(faw.generateWhereClause()).append("</script>");
        return sql.toString();
    }

}
